import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms, models
from torch.utils.data import DataLoader
from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152, vgg11, vgg13, vgg16, vgg19, mobilenet_v2, mobilenet_v3_large
from torchvision.models.shufflenetv2 import shufflenet_v2_x1_0

# from network_NN import DualBranchNet, adjust_layer_sizes_to_128x128, propocess_fun
from network_NN import DualBranchNet, propocess_fun

#-------------------------------------------------
import sys
# sys.path.append('/public/chenhaozhe/4_model_robust_hash/code/shrinkbench_CHZ_new')
# import models
from torchvision import models
import torch
import os
# os.environ['DATAPATH'] = '/data-x/g10/chzchen/CV_datasets'
# os.environ['WEIGHTSPATH'] = '/public/chenhaozhe/4_model_robust_hash/code/shrinkbench_CHZ_new/pretrained/'
from IPython.display import clear_output
import pdb
import numpy as np
import pandas as pd
import time
import torch
import timm
import traceback
import random
from random import choice
import copy

import torch.nn.utils.prune as prune

# Module-level settings.
# NOTE(review): neither name is referenced elsewhere in the visible part of
# this file -- confirm they are still needed before relying on them.
time_all = 0
strategy = 'GlobalMagWeight'

import numpy as np
import glob

from tqdm import tqdm
import math

# from vam_net import VAM

from network_NN import phi_i, T_i, compute_h_i

#--------------------------------------------------------------
# import torch
# from torch import nn
from torch.nn import functional as F
# from torch import optim

import random

# model_name_list = ['resnet18','resnet34'] # Example subset

# Architectures visited by the training loop; every entry is passed to
# get_model(), so each name must be one that get_model() accepts.
model_name_list = ['resnet18', 'resnet34','resnet50','resnet101','resnet152','vgg11', 'vgg13', 'vgg16', 'vgg19', 'shufflenet', 'mobilenetv2']
# Datasets visited by the training loop; used to select the checkpoint
# folder and file name for each architecture.
model_dataset_list = ['CIFAR10','CIFAR100','MNIST','FashionMNIST']

def get_random_index(lst, a):
    """Pick a random position in ``lst`` that is different from ``a``.

    Args:
        lst: Any sized sequence; only its length is used.
        a: Index that must not be returned.

    Returns:
        A randomly chosen index of ``lst`` other than ``a``. When ``a`` is
        not a valid index of ``lst`` the function does not raise; it returns
        the string ``"Index 'a' is out of range."`` instead.

    Raises:
        IndexError: From ``random.choice`` when ``lst`` has exactly one
            element, because no alternative index exists.
    """
    size = len(lst)

    # Guard clause: an invalid 'a' is reported through the return value.
    if a < 0 or size <= a:
        return "Index 'a' is out of range."

    # Every position except the excluded one is a candidate.
    candidates = [position for position in range(size) if position != a]
    return random.choice(candidates)

# Function to apply L1 unstructured pruning to all layers with weights
def prune_model(model, amount=0.5):
    """Apply L1 unstructured pruning to the weights of every conv/linear layer.

    The model is modified in place; nothing is returned.

    Args:
        model: ``nn.Module`` whose ``nn.Conv2d`` and ``nn.Linear`` sub-modules
            (including ``model`` itself, if it is one) are pruned.
        amount: Forwarded unchanged to ``prune.l1_unstructured`` for each
            layer's ``weight`` parameter.
    """
    prunable_types = (nn.Conv2d, nn.Linear)
    targets = [layer for layer in model.modules() if isinstance(layer, prunable_types)]
    for layer in targets:
        prune.l1_unstructured(layer, name='weight', amount=amount)


def get_model(model_name):
    """Build a torchvision model for ``model_name`` with ``pretrained=False``.

    Names starting with ``resnet`` or ``vgg`` are looked up on
    ``torchvision.models`` under the same name; the short names
    ``shufflenet``, ``mobilenetv2`` and ``mobilenetv3`` are mapped to their
    torchvision constructor names.

    Args:
        model_name: Architecture name (string).

    Returns:
        The model object returned by the torchvision constructor.

    Raises:
        ValueError: If ``model_name`` matches none of the supported names.
    """
    # Short name used in this project -> constructor name in torchvision.models.
    constructor_aliases = {
        'shufflenet': 'shufflenet_v2_x1_0',
        'mobilenetv2': 'mobilenet_v2',
        'mobilenetv3': 'mobilenet_v3_large',
    }

    if model_name.startswith('resnet') or model_name.startswith('vgg'):
        constructor_name = model_name
    elif model_name in constructor_aliases:
        constructor_name = constructor_aliases[model_name]
    else:
        raise ValueError("Model not supported")

    constructor = getattr(models, constructor_name)
    return constructor(pretrained=False)

def create_new_fast_conv_1(layer_old, in_put):
    """Convolve ``in_put`` with a channel-cropped copy of ``layer_old``'s kernel.

    The layer's weight is sliced along its input-channel axis to the first
    ``in_put.shape[1]`` channels so that one layer can process inputs with
    fewer channels than it was built for. The convolution always uses
    ``padding=1`` and the functional defaults for everything else; the
    layer's own stride/padding settings are not consulted.

    Args:
        layer_old: Convolution module providing ``weight`` and ``bias``
            attributes. ``bias`` may be ``None`` (layer built with
            ``bias=False``).
        in_put: Input tensor; ``in_put.shape[1]`` is taken as its channel
            count.

    Returns:
        The tensor produced by ``nn.functional.conv2d``.
    """
    # Read the attributes directly instead of scanning named_parameters():
    # that scan left the bias unbound for bias-free layers and the weight
    # unbound for pruned layers, whose parameter is named 'weight_orig'.
    conv_weight = layer_old.weight[:, 0:in_put.shape[1], :, :]
    conv_bias = layer_old.bias

    # Apply the cropped kernel (plus bias, when present) to the input.
    return nn.functional.conv2d(in_put, conv_weight, bias=conv_bias, padding=1)


def check_nan_inf(value):
    """Tell whether ``value`` contains a NaN or an infinity.

    Args:
        value: A Python ``float`` or a ``torch.Tensor``.

    Returns:
        A zero-dimensional boolean tensor that is true when any element is
        NaN or +/-Inf.

    Raises:
        ValueError: If ``value`` is neither a ``float`` nor a tensor
            (plain ``int`` values are rejected as well).
    """
    if isinstance(value, float):
        # Wrap the scalar so the tensor predicates below can be reused.
        candidate = torch.tensor([value])
    elif torch.is_tensor(value):
        candidate = value
    else:
        raise ValueError("Unsupported type for check_nan_inf")

    has_nan = torch.isnan(candidate).any()
    return has_nan or torch.isinf(candidate).any()
# Ways in which a "similar" (positive) model is derived from the clean one.
modified_mode_list = ['finetune', 'prune']

# Checkpoint layout: <_WEIGHTS_ROOT>/<per-dataset folder>/<file name>.pth
_WEIGHTS_ROOT = '/data/liuruiheng/MM_code'
_DATASET_DIRS = {
    'CIFAR10': 'cifar10_model',
    'CIFAR100': 'cifar100_model',
    'MNIST': 'MNIST_model',
    'FashionMNIST': 'FashionMNIST_model',
}
# Pruning amounts sampled for the 'prune' modification mode.
_PRUNE_RATIOS = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
# Probability that a (architecture, dataset) pair is trained on in an epoch.
_TRAIN_FRACTION = 0.75


def _original_weights_path(model_name, dataset_name):
    """Return the checkpoint path of the original ('ori_') model."""
    folder = _DATASET_DIRS[dataset_name]
    return f'{_WEIGHTS_ROOT}/{folder}/ori_{model_name}_{dataset_name}.pth'


def _finetuned_weights_path(model_name, dataset_name, epoch_index, save_idx_index):
    """Return the checkpoint path of one fine-tuned ('model_') snapshot."""
    folder = _DATASET_DIRS[dataset_name]
    stem = f'model_{model_name}_{epoch_index}_{save_idx_index}'
    # NOTE(review): unlike the other three datasets, the FashionMNIST
    # fine-tuned file name carries no dataset suffix. The name is reproduced
    # as it was; confirm against the files on disk.
    suffix = '' if dataset_name == 'FashionMNIST' else f'_{dataset_name}'
    return f'{_WEIGHTS_ROOT}/{folder}/{stem}{suffix}.pth'


def _load_cnn(model_name, weights_path, target_device):
    """Build ``model_name``, load its weights and return it in eval mode.

    NOTE(review): get_model() is called without a class count, so the
    checkpoint must match the head size of the default architecture.
    """
    cnn = get_model(model_name).to(target_device)
    # map_location keeps loading functional on machines without a GPU.
    cnn.load_state_dict(torch.load(weights_path, map_location=target_device))
    cnn.to(target_device)
    cnn.eval()
    return cnn


def _project_features(conv_layer, beta, count, target_device):
    """Run ``beta`` through the shared projection conv unless ``count`` is 0.

    ``count`` is the third/fourth value returned by propocess_fun for the
    matching branch; a value of 0 means ``beta`` is passed on untouched.
    """
    if count == 0:
        return beta
    batched = torch.unsqueeze(beta, dim=0).to(target_device)
    return create_new_fast_conv_1(conv_layer, batched)


LSH_model = DualBranchNet()
optimizer = optim.Adam(LSH_model.parameters(), lr=1e-4, weight_decay=1e-5)
num_epochs = 20  # You can adjust this
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LSH_model.to(device)

# Set when a NaN/Inf loss is seen so that all three nested loops terminate.
stop_training = False

for epoch in range(num_epochs):
    LSH_model.train()

    running_loss = 0.0
    # Number of optimisation steps actually taken this epoch (skipped
    # "test set" pairs do not count towards the reported mean loss).
    trained_steps = 0

    # Iterate over architectures, then over datasets. tqdm wraps the list
    # (not the enumerate object) so the progress bar knows the total.
    for i, ori_model_name_index in enumerate(tqdm(model_name_list)):
        for j, ori_model_dataset_index in enumerate(tqdm(model_dataset_list)):

            # NOTE(review): the held-out split is re-drawn every epoch, so a
            # pair skipped now may be trained on in another epoch.
            if random.random() >= _TRAIN_FRACTION:
                print('test_set_modelname:'+str(ori_model_name_index)+' '+str(ori_model_dataset_index))
                continue

            # ---- anchor: the original model for this architecture/dataset
            cnn_ori_weights_path = _original_weights_path(ori_model_name_index, ori_model_dataset_index)
            cnn_clean = _load_cnn(ori_model_name_index, cnn_ori_weights_path, device)

            # ---- positive: a fine-tuned or pruned variant of the anchor
            modified_mode = choice(modified_mode_list)

            if modified_mode == 'finetune':
                modified_model_name_index = ori_model_name_index
                modified_model_dataset_index = ori_model_dataset_index
                epoch_index = random.choice(range(1, 11))
                save_idx_index = random.choice(range(1, 6))

                cnn_modified_weights_path = _finetuned_weights_path(
                    modified_model_name_index,
                    modified_model_dataset_index,
                    epoch_index,
                    save_idx_index,
                )
                cnn_modified = _load_cnn(modified_model_name_index, cnn_modified_weights_path, device)

            elif modified_mode == 'prune':
                ru = random.choice(_PRUNE_RATIOS)
                cnn_modified = copy.deepcopy(cnn_clean)
                prune_model(cnn_modified, amount=ru)

            cnn_modified.to(device)
            cnn_modified.eval()

            # ---- negative: same architecture, randomly chosen dataset.
            # The list must be indexed with the integer position `i`; the
            # previous code indexed it with the model-name string, which
            # raises TypeError.
            poison_model_index = i
            poison_model_name_index = model_name_list[poison_model_index]
            # NOTE(review): the random dataset may equal the anchor's dataset,
            # in which case the negative model is identical to the anchor.
            poison_model_dataset_index = random.choice(model_dataset_list)

            cnn_poisoned_weights_path = _original_weights_path(poison_model_name_index, poison_model_dataset_index)
            cnn_poisoned = _load_cnn(poison_model_name_index, cnn_poisoned_weights_path, device)

            # ---- feed the three models to the LSH network
            optimizer.zero_grad()
            beta_c_clean, beta_f_clean, n_clean, l_clean = propocess_fun(cnn_clean)
            beta_c_poisoned, beta_f_poisoned, n_poisoned, l_poisoned = propocess_fun(cnn_poisoned)
            beta_c_modified, beta_f_modified, n_modified, l_modified = propocess_fun(cnn_modified)

            # Projection conv shared by all six feature tensors of this step.
            # NOTE(review): it is re-created with fresh random weights on
            # every step and its parameters are not given to the optimizer
            # (which only holds LSH_model.parameters()) -- confirm that an
            # untrained per-step projection is intended.
            conv_layer = nn.Conv2d(in_channels=800, out_channels=32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)).to(device)

            gamma_c_clean = _project_features(conv_layer, beta_c_clean, n_clean, device)
            gamma_f_clean = _project_features(conv_layer, beta_f_clean, l_clean, device)
            gamma_c_poisoned = _project_features(conv_layer, beta_c_poisoned, n_poisoned, device)
            gamma_f_poisoned = _project_features(conv_layer, beta_f_poisoned, l_poisoned, device)
            gamma_c_modified = _project_features(conv_layer, beta_c_modified, n_modified, device)
            gamma_f_modified = _project_features(conv_layer, beta_f_modified, l_modified, device)

            H_o = LSH_model(gamma_c_clean, gamma_f_clean)
            H_p = LSH_model(gamma_c_modified, gamma_f_modified)
            H_np = LSH_model(gamma_c_poisoned, gamma_f_poisoned)

            # Sigmoid of the mean squared distance to the anchor output.
            # Minimising O_s - O_d pulls the modified model's output towards
            # the anchor and pushes the unrelated model's output away.
            O_s = torch.sigmoid(torch.mean((H_o - H_p) ** 2))
            O_d = torch.sigmoid(torch.mean((H_o - H_np) ** 2))
            class_loss = O_s - O_d

            loss = class_loss

            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            trained_steps += 1

            print('loss:'+str(loss))

            if check_nan_inf(running_loss):
                print("NaN or Inf in loss, stopping training")
                # A bare `break` would only leave the dataset loop; the flag
                # propagates the stop to the architecture and epoch loops.
                stop_training = True
                break

        if stop_training:
            break

    if stop_training:
        break

    # Mean over the steps actually taken; max() guards the (unlikely) epoch
    # in which every pair was skipped.
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/max(trained_steps, 1)}")
# Save the trained model
torch.save(LSH_model.state_dict(), 'train_1_T_loss_whole_dataset_with_pruning_75_trainning_no_lr_drop_CORR_OnlyLSH_new.pth')